import torch
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.optim
import json
import torch.utils.data.sampler
import os
import glob
import random
import time

import configs
import backbone
import data.feature_loader as feat_loader
from data.datamgr import SetDataManager
from methods.maml_moml import MAML_MOML
from methods.protonet_moml import ProtoNet_MOML
from methods.boil_moml import BOIL_MOML
from methods.maml_test import MAML_Test
from methods.constrained_meta import Constrained_meta
from methods.constrained_implicit import Constrained_implicit
from io_utils import model_dict, parse_args, get_resume_file, get_best_file , get_assigned_file


if __name__ == '__main__':
    params = parse_args('test')
    acc_all = []
    iter_num = 600

    few_shot_params = dict(n_way = params.test_n_way , n_support = params.n_shot) 
    
    os.environ['CUDA_VISIBLE_DEVICES'] = str(params.device)
        
    if params.method == 'maml_moml':
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model           = MAML_MOML(  model_dict[params.model], approx = (params.method == 'maml_moml') , **few_shot_params )
        model.weighting_mode = params.weighting_mode
        
    elif params.method == 'protonet_moml':
        model = ProtoNet_MOML( model_dict[params.model], **few_shot_params )
        model.weighting_mode = params.weighting_mode
    elif params.method == 'boil_moml':
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = BOIL_MOML( model_dict[params.model], **few_shot_params )
        model.weighting_mode = params.weighting_mode
        
    elif params.method == 'maml_test':
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML_Test( model_dict[params.model], **few_shot_params )
        model.weighting_mode = params.weighting_mode
    elif params.method == 'constrained_implicit':
        backbone.ConvBlock.maml = False
        backbone.SimpleBlock.maml = False
        backbone.BottleneckBlock.maml = False
        backbone.ResNet.maml = False
        model = Constrained_implicit(  model_dict[params.model], **few_shot_params )
        if params.weighting_mode=='SOML' or params.weighting_mode=='COML':
            model.weighting_mode = params.weighting_mode
        else:
            model.weighting_mode = 'COML'
        model.meta_lambda=1.0
        if params.n_shot==1:
            model.meta_lambda=8.0
            if params.dataset== 'CUB':
                model.meta_lambda=7.0
    elif params.method == 'constrained_meta':
        backbone.ConvBlock.maml = False
        backbone.SimpleBlock.maml = False
        backbone.BottleneckBlock.maml = False
        backbone.ResNet.maml = False
        model = Constrained_meta(  model_dict[params.model], **few_shot_params )
        model.weighting_mode = 'COML'
        model.meta_lambda=1.0
        if params.n_shot==1:
            model.meta_lambda=8.0
            if params.dataset== 'CUB':
                model.meta_lambda=7.0

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s_%s' %(configs.save_dir, params.dataset, params.model, params.method, params.weighting_mode, params.mark)
  
    if params.train_aug:
        checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++'] :
        checkpoint_dir += '_%dway_%dshot' %( params.train_n_way, params.n_shot)
        
    print(checkpoint_dir)

    #modelfile   = get_resume_file(checkpoint_dir)
    if 'Conv' in params.model:
        image_size = 84 
    else:
        image_size = 224
        
    split = params.split
    
    if params.save_iter == -1:
        modelfile   = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
            datamgr         = SetDataManager(image_size, n_eposide = iter_num, n_query = 15 , **few_shot_params)
            loadfile    = configs.data_dir[params.dataset] + split + '.json'
            novel_loader     = datamgr.get_data_loader( loadfile, aug = False)
            model.eval()
            model.test_loop( novel_loader, return_std = True)
    elif params.save_iter == -2:
        modelfile   = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
            datamgr         = SetDataManager(image_size, n_eposide = iter_num, n_query = 15 , **few_shot_params)
            loadfile    = configs.data_dir[params.dataset] + split + '.json'
            novel_loader     = datamgr.get_data_loader( loadfile, aug = False)
            model.eval()
            model.test_loop( novel_loader, return_std = True)
        for i in range(1000):
            iter_i = i * params.save_freq
            print(iter_i)
            modelfile   = get_assigned_file(checkpoint_dir,iter_i)
            if modelfile is not None:
                tmp = torch.load(modelfile)
                model.load_state_dict(tmp['state'])
                datamgr         = SetDataManager(image_size, n_eposide = iter_num, n_query = 15 , **few_shot_params)
                loadfile    = configs.data_dir[params.dataset] + split + '.json'
                novel_loader     = datamgr.get_data_loader( loadfile, aug = False)
                model.eval()
                model.test_loop( novel_loader, return_std = True)
    else:
        modelfile   = get_assigned_file(checkpoint_dir,params.save_iter)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
            datamgr         = SetDataManager(image_size, n_eposide = iter_num, n_query = 15 , **few_shot_params)
            loadfile    = configs.data_dir[params.dataset] + split + '.json'
            novel_loader     = datamgr.get_data_loader( loadfile, aug = False)
            model.eval()
            model.test_loop( novel_loader, return_std = True)


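#    NOTE: the result logging below is disabled. It uses acc_mean and acc_std,
#    which nothing in this script computes (the value returned by test_loop is
#    discarded), so it cannot be re-enabled as-is.
#    split = params.split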
#    if params.save_iter != -1:
#        split_str = split + "_" +str(params.save_iter)
#    else:
#        split_str = split


#    with open('./record/results.txt' , 'a') as f:
#        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime()) 
#        aug_str = '-aug' if params.train_aug else ''
#        aug_str += '-adapted' if params.adaptation else ''
#        if params.method in ['baseline', 'baseline++'] :
#            exp_setting = '%s-%s-%s-%s%s %sshot %sway_test' %(params.dataset, split_str, params.model, params.method, aug_str, params.n_shot, params.test_n_way )
#        else:
#            exp_setting = '%s-%s-%s-%s%s %sshot %sway_train %sway_test' %(params.dataset, split_str, params.model, params.method, aug_str , params.n_shot , params.train_n_way, params.test_n_way )
#        acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, acc_mean, 1.96* acc_std/np.sqrt(iter_num))
#        f.write( 'Time: %s, Setting: %s, Acc: %s \n' %(timestamp,exp_setting,acc_str)  )
